{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#只是将一张图片作为例子，演示一下FCN的训练，优化以及预测的这样一个过程，单张图片来将网络进行一个测试\n",
    "from __future__ import division\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import os\n",
    "import sys\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "slim = tf.contrib.slim\n",
    "\n",
    "sys.path.append(os.path.expanduser(\"D:/python/GitHub/tensorflow/models/research/slim/\"))#需要用到VGG网络模型的定义以及需要用到的预处理，这样的过程，函数和一些变量，所以要import的path\n",
    "\n",
    "def get_kernel_size(factor):\n",
    "    return factor * 2 - factor % 2\n",
    "\n",
    "def upsample_filt(size):#获取双线性插值的filter\n",
    "    factor = (size + 1) // 2\n",
    "    if size % 2 == 1:\n",
    "        center = factor -1\n",
    "    else:\n",
    "        center = factor - 0.5\n",
    "    og = np.ogrid[:size, :size]\n",
    "    return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)\n",
    "\n",
    "def bilinear_upsample_weights(factor, number_of_class):#将双线性插值得到的filter转换成我们的权重，\n",
    "                                                       #再用权重去初始化我们的kernel,\n",
    "                                                       #上面这几个是平常的标准算法，只是用python把他实现了而已\n",
    "    filter_size = get_kernel_size(factor)\n",
    "    weights = np.zeros((filter_size, filter_size, number_of_class, number_of_class), dtype=np.float32)\n",
    "    upsample_kernel = upsample_filt(filter_size)\n",
    "\n",
    "    for i in range(number_of_class):\n",
    "        weights[:,:,i,i] = upsample_kernel\n",
    "    return weights\n",
    "\n",
    "from nets import vgg\n",
    "from preprocessing import vgg_preprocessing\n",
    "\n",
    "os.environ[\"CUDA_VISIABLE_DEVICES\"] = '0'#在linux上去使用GPU的时候，将这个环境变量（这个变量可以去控制当前这个程序 \n",
    "                                         #能够看到哪一块显卡，改成空，就不使用GPU去计算，这个是通常设置跑在哪块显卡\n",
    "                                        #上的方式，也可在session config里面去指定）\n",
    "\n",
    "checkpoints_dir = \"D:/python/GitHub/tensorflow/models/research/data/pre_trained/\"\n",
    "image_filename = \"D:/python/GitHub/tensorflow/models/research/data/object_1.jpg\"\n",
    "annotation_filename = \"D:/python/GitHub/tensorflow/models/research/data/segment_1.png\"\n",
    "\n",
    "fig_size =  [15, 14]\n",
    "plt.rcParams[\"figure.figsize\"] = fig_size\n",
    "\n",
    "\n",
    "#下面进入计算图的定义阶段\n",
    "tf.reset_default_graph()#这是防止在重复运行的时候，出现变量scope冲突的问题，\n",
    "image_filename_placeholder = tf.placeholder(dtype=tf.string)#图片文件名\n",
    "annotation_filename_placeholder = tf.placeholder(dtype=tf.string)#ground truth\n",
    "is_training_placeholder = tf.placeholder(dtype=tf.bool)\n",
    "\n",
    "feed_dict_to_use = {image_filename_placeholder:image_filename,annotation_filename_placeholder:annotation_filename,is_training_placeholder:True}\n",
    "\n",
    "image_tensor = tf.read_file(image_filename_placeholder)#在计算图内部进行文件的读取，通常进行单张图片的，比较少量的测试数据的话，\n",
    "                                                     #更倾向在图外面进行这部分的读取操作，一次性读取在内存之间进行计算图的读取\n",
    "annotation_tensor = tf.read_file(annotation_filename_placeholder)#我们有时会将文件读取放到计算图内部，特别是有很多个文件需要读取\n",
    "                                                                  #或者每一调数据都代表着一个文件，然后需要随机化的输入我们的计算图\n",
    "\n",
    "image_tensor = tf.image.decode_jpeg(image_tensor, channels=3)#这两行就是decode\n",
    "annotation_tensor = tf.image.decode_png(annotation_tensor, channels=1)#保存成一个png格式的图片，不同颜色值来代表不同的分类；\n",
    "                                                                     #我们需要将这个颜色的分类映射到我们的某一个\n",
    "                                                                     #class的分类上去，至少分配一个id号，方便回归的时候进行计算，可能变成onehot的向量计算交叉熵\n",
    "                                                            #只用一张图，我们只需要区别是前景类还是背景类\n",
    "                                                          #其实是一张3通道的，这里强行变成单通道的，channels=1\n",
    "class_labels_tenssor = tf.greater_equal(annotation_tensor, 1)#考察每个像素是否>=1,若>=1,说明属于一个分类的grouth truth,代表着某一个object,\n",
    "background_labels_tensor = tf.less_equal(annotation_tensor, 1)#小于1 就说明是0，就会把它置为1，其余的把它简单粗暴的置为前景类，\n",
    "\n",
    "bit_mask_class = tf.to_float(class_labels_tenssor)#前景和背景，\n",
    "bit_mask_background = tf.to_float(background_labels_tensor)\n",
    "\n",
    "combined_mask = tf.concat(axis= 2, values=[bit_mask_background, bit_mask_class])#这两张groud truth的矩阵我们把它对列起来，形成3阶的张量，\n",
    "                                                                                #groud truth的，一般都分成长，宽，以及深\n",
    "                                                                           #在第3个维度上，深度这个维度上可以看做是一个onehot表示，\n",
    "                           #深度上的每一个维度，每一个矩阵都代表着某一个分类，是一个扩展了的onehot的形式，\n",
    "                           #换一个角度去理解的话，长*宽*深，上面的每一个像素点，如果在深度上去截取的话，我们可以得到1*1*classNum,这样的一个1阶向量，\n",
    "                        #换句话说，我们在空间上维度上截取的每一个像素点，在深度上都是一个标准的onehot的向量，就像我们之前讲过的那样\n",
    "                       \n",
    "flat_labels = tf.reshape(combined_mask, shape=(-1,2))\n",
    "#接下来我们从Vgg的vgg_preprocessing里面获取预处理的一些参数\n",
    "from preprocessing.vgg_preprocessing import (_mean_image_subtraction, _R_MEAN, _G_MEAN, _B_MEAN)#比如说整张图片减均值，分成RGB三个通道，这也是Vgg处理里面比较有特色的一个地方\n",
    "                                                                                                 #它计算了RGB上面的平均值，全局上减均值这样的一个操作\n",
    "upsample_factor = 32#上采样倍数，按上述的在空间尺度上缩减为原来的32分之1，所以设为32，\n",
    "number_of_classes = 2#由于刚才的简单粗暴的操作，他变成了2个分类，前景和背景类，像检测中那样，我们的分割也是要多出一个背景分类的，一般将其置为0\n",
    "log_folder = os.path.expanduser('D:/source/models/research/segment_log_folder')\n",
    "#后面就是一系列的预处理操作，包括减均值等等\n",
    "vgg_checkpoint_path = os.path.join(checkpoints_dir,'vgg_16.ckpt')\n",
    "\n",
    "image_float = tf.to_float(image_tensor, name=\"ToFloat\")\n",
    "original_shape = tf.shape(image_float)[0:2]\n",
    "\n",
    "mean_center_image = _mean_image_subtraction(image_float,[_R_MEAN,_G_MEAN,_B_MEAN])\n",
    "\n",
    "#特殊操作，需要去计算在空间尺度上是不是能够被32整除，如果不能整除的话，我们需要对他进行处理，现在网络能够接受任意大小的输入\n",
    "#如果输入是224的倍数的还好，若果不能被32整除的话，不断进行池化，肯定会遇到每一次除以2，中间粗来一个小数，会向下取整\n",
    "#解决方法是补0，补齐到大于32的整数倍，再下面就是初始化\n",
    "target_input_size_factor = tf.ceil(tf.div(tf.to_float(original_shape),tf.to_float(upsample_factor)))\n",
    "\n",
    "                                \n",
    "target_input_size = tf.to_int32(tf.multiply(target_input_size_factor,upsample_factor))\n",
    "padding_size = (target_input_size - original_shape)//2\n",
    "\n",
    "mean_center_image = tf.image.pad_to_bounding_box(mean_center_image, padding_size[0], padding_size[1], target_input_size[0], target_input_size[1])\n",
    "processed_image = tf.expand_dims(mean_center_image, 0)\n",
    "\n",
    "upsample_filter_np = bilinear_upsample_weights(upsample_factor, number_of_classes)\n",
    "\n",
    "upsample_filter_tensor = tf.Variable(upsample_filter_np, name='vgg_16/fc8/t_conv')\n",
    "\n",
    "with slim.arg_scope(vgg.vgg_arg_scope()):#调用slim下面的nets,来在我们下面的计算图中构建一个Vgg16的这样一个网络\n",
    "    logits, end_points = vgg.vgg_16(processed_image, \n",
    "                                    num_classes = 2, #这个参数传递进去之后，就会在我们的Vgg网络里面，fc8里面生成一个\n",
    "                                                     # 分类数为2的这样一个卷积，可能是1*1的kernel,对最后一层4096这样长度的feature map进行卷积，\n",
    "                                                      #在深度上是2，即用两个不同的kernel对他进行卷积，\n",
    "                                                 \n",
    "                                    is_training = is_training_placeholder, #看一下是不是train,涉及到dropout是不是要抛弃的问题，\n",
    "                                    spatial_squeeze = False, #这个是对最后得到的结果logits进不进行squeeze,最后得到了一个1*1*4096，进行squeeze，那么就变成1阶向量\n",
    "                                    fc_conv_padding = 'SAME')#最后做7*7padding的时候，我们采用的方式，用这个会保持7*7，valid就变成1*1）\n",
    "                                                      #现在得益于slim里面的vgg,已经被改成了全卷积的形式，如果不改变输入尺度的情况下，他们可以看做是\n",
    "                                                      #完全等价的，所以不论是在训练，还是在获取权重方面，都没有\n",
    "                                                      #任何的问题，前人设计的时候就会可能就考虑要用在物体检测方面的等\n",
    "downsampled_logits_shape = tf.shape(logits)\n",
    "\n",
    "upsampled_logits_shape = tf.stack([downsampled_logits_shape[0], \n",
    "                                   original_shape[0],\n",
    "                                   original_shape[1],\n",
    "                                   downsampled_logits_shape[3]])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#8s factor code\n",
    "'''\n",
    "pool3_feature = end_points['vgg_16/pool3']\n",
    "with tf.variable_scope('vgg_16/fc8'):\n",
    "    aux_logits_8s = slim.conv2d(pool3_feature, \n",
    "                                     2, \n",
    "                                     [1,1], \n",
    "                                     activation_fn=None, \n",
    "                                     weights_initializer =tf.zeros_initializer,\n",
    "                                     scope='conv_pool3')\n",
    "\n",
    "upsample_filter_np_x4 = bilinear_upsample_weights(4, number_of_classes)\n",
    "upsample_filter_tensor_x4 = tf.Variable(upsample_filter_np_x4, name=\"vgg_16/fc8/t_conv_x4\")\n",
    "\n",
    "#进行一个上采样\n",
    "upsampled_logits = tf.nn.conv2d_transpose(logits,    \n",
    "                                          upsample_filter_np_x4,\n",
    "                                          output_shape=tf.shape(aux_logits_8s), \n",
    "                          \n",
    "                                          strides=[1,4,4,1],\n",
    "                                          padding='SAME')\n",
    "\n",
    "upsampled_logits = upsampled_logits + aux_logits_8s\n",
    "\n",
    "upsample_filter_np_x8 = bilinear_upsample_weights(upsample_factor, number_of_classes)\n",
    "\n",
    "upsample_filter_tensor_x8 = tf.Variable(upsample_filter_np_x8, name='vgg_16/fc8/t_conv_x8')\n",
    "\n",
    "upsampled_logits = tf.nn.conv2d_transpose(upsampled_logits, upsample_filter_tensor_x8, output_shape=upsampled_logits_shape,\n",
    "                                         strides=[1, upsample_factor, upsample_factor, 1], padding='SAME')\n",
    "'''\n",
    "\n",
    "#16s factor code\n",
    "'''\n",
    "\n",
    "pool4_feature = end_points['vgg_16/pool4']#pool4的结果\n",
    "with tf.variable_scope('vgg_16/fc8'):\n",
    "    aux_logits_16s = slim.conv2d(pool4_feature, #定义新的1*1的卷积，对pool4的feature进行一个2分类，\n",
    "                                2, \n",
    "                                [1,1], \n",
    "                                activation_fn=None,#不需要激活函数\n",
    "                                weights_initializer =tf.zeros_initializer,#刚开始全部为0\n",
    "                                scope='conv_pool4')\n",
    "\n",
    "upsample_filter_np_x2 = bilinear_upsample_weights(2, number_of_classes)\n",
    "\n",
    "upsample_filter_tensor_x2 = tf.Variable(upsample_filter_np_x2, name=\"vgg_16/fc8/t_conv_x2\")\n",
    "\n",
    "upsampled_logits = tf.nn.conv2d_transpose(logits,#对logits的结果进行一个2倍的上采样，（反卷积，转置卷积）\n",
    "                                         upsample_filter_np_x2, \n",
    "                                         output_shape=tf.shape(aux_logits_16s), #得到的卷积核pool4的尺寸完全一样的\n",
    "                                          strides=[1,2,2,1], \n",
    "                                          padding='SAME')\n",
    "\n",
    "upsampled_logits = upsampled_logits + aux_logits_16s#将ool4分类的结果和我们最终的分类结果进行上采样之后的结果进行加和，\n",
    "\n",
    "upsample_filter_np_x16 = bilinear_upsample_weights(upsample_factor, number_of_classes)\n",
    "\n",
    "upsample_filter_tensor_x16 = tf.Variable(upsample_filter_np_x16, name='vgg_16/fc8/t_conv_x16')#再对转置卷积进行16倍的上采样\n",
    "\n",
    "upsampled_logits = tf.nn.conv2d_transpose(upsampled_logits, \n",
    "                                          upsample_filter_tensor_x16,\n",
    "                                          output_shape=upsampled_logits_shape,\n",
    "                                         strides=[1, upsample_factor, upsample_factor, 1], \n",
    "                                         padding='SAME')\n",
    "'''\n",
    "\n",
    "#32s factor code\n",
    "upsampled_logits = tf.nn.conv2d_transpose(logits, #logits作为input\n",
    "                                          upsample_filter_tensor, #上采样需要用到的kernel,双线性插值得到的kernel \n",
    "                                          output_shape=upsampled_logits_shape,#我们希望得到的精确的输出的大小，\n",
    "                                                                                #到底是多大，在caffe里面是指定位置裁切\n",
    "                                                                                #在tf里面已经把这一步整合到方法里面去，只需要告诉他\n",
    "                                                                                #精确的大小是多大，就会在相应的位置计算出来，并进行裁切，方便\n",
    "                                          strides=[1, upsample_factor, upsample_factor, 1],#strides=[1,32,32,1]是在缩小了32倍的这样的空间\n",
    "                                                                                           #尺度进行反卷积，所以是32\n",
    "                                          padding='SAME')\n",
    "#transpose_convaltion是一个卷积的反向运算，用valid，，相比same的话，要更大一些"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "flat_logits = tf.reshape(tensor=upsampled_logits, shape=(-1, number_of_classes))#将原图等大的logits，展平，\n",
    "\n",
    "#和groud truth做交叉熵的计算，（双线性插值结合到我们的卷积神经网络中），目的就是方便的进行前向和反向传播，来计算我们的\n",
    "#我们 的交叉熵，\n",
    "cross_entropies = tf.nn.softmax_cross_entropy_with_logits(logits = flat_logits, labels=flat_labels)\n",
    "#下面就和普通的神经网络的优化过程相差不多，\n",
    "cross_entropy_sum = tf.reduce_sum(cross_entropies)\n",
    "\n",
    "pred = tf.argmax(upsampled_logits, axis=3)#再打概率分类类别，\n",
    "\n",
    "probabilities = tf.nn.softmax(upsampled_logits)#得到每一个像素具体法人概率的输出\n",
    "\n",
    "with tf.variable_scope(\"adam_vars\"):\n",
    "    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)#AdamOptimizer优化器，对我们的损失进行迭代的优化\n",
    "    gradients = optimizer.compute_gradients(loss=cross_entropies)#这里没有采用我们的minimize的方法，\n",
    "                                                                  #首先计算出他的梯度\n",
    "\n",
    "    for grad_var_pair in gradients:\n",
    "        current_variable = grad_var_pair[1]\n",
    "        current_gradient = grad_var_pair[0]\n",
    "\n",
    "        gradient_name_to_save = current_variable.name.replace(\":\",\"_\")\n",
    "        tf.summary.histogram(gradient_name_to_save, current_gradient)#之所以分两步是因为，是要将梯度输入到我们的额\n",
    "                                                                  #summary里面去，方便我们可视化的观察\n",
    "\n",
    "    train_step = optimizer.apply_gradients(grads_and_vars=gradients)#在对其进行应用梯度，用来更新我们所有的的权重\n",
    "\n",
    "#获取Vgg里面一些变量的定义\n",
    "#主要是恢复预训练的 pre cheng 的一些权重的时候，也就是imagenet训练Vgg的得到的这个checkpoint的，其中记录的预训练的\n",
    "#的权重，用来初始化我们的模型，要排除一些权重，fc8是分类器，4096*1001，现在只有两个分类，所以情况不同了，所以将其排除\n",
    "#所以在restore的时候要排除掉，这里没有restore，而只是获取了将要restore的权重的列表，因为采用了Adam优化算法\n",
    "#所以有一些变量，在原来的checkpoint里面是没有的，比如adam用到的“B”->\"贝塔\"，adam_vars，\n",
    "vgg_expect_fc8_weights = slim.get_variables_to_restore(exclude=['vgg_16/fc8', 'adam_vars'])\n",
    "\n",
    "\n",
    "vgg_fc8_weithts = slim.get_variables_to_restore(include=['vgg_16/fc8'])#restore的时候从其中去掉了，我们不对他进行restore\n",
    "                                                                    #那么我们就需要有一个方式对他进行初始化，\n",
    "adam_optimizer_variables = slim.get_variables_to_restore(include=['adam_vars'])#这个和上面的vgg_fc8_weithts\n",
    "                                                                          #都是需要初始化的参数，fc8和adma相关参数\n",
    "                                                                        #adam_optimizer_variables，vgg_fc8_weithts将其保存到列表里面去\n",
    "tf.summary.scalar('cross_entropy_loss', cross_entropy_sum)#summary就是记录一些操作\n",
    "\n",
    "merged_summary_op = tf.summary.merge_all()\n",
    "\n",
    "summary_string_writer = tf.summary.FileWriter(log_folder)#这里没有采用像之前一样先定义一个server,将session作为参数传递进去的\n",
    "                                     #方式，而是将他通过系统调用的方式做成一个op,之后只需要用session去run这个欧赔，就可以了，\n",
    "\n",
    "if not os.path.exists(log_folder):\n",
    "    os.makedirs(log_folder)\n",
    "\n",
    "read_vgg_weights_except_fc8_func = slim.assign_from_checkpoint_fn(vgg_checkpoint_path, vgg_expect_fc8_weights)\n",
    "\n",
    "vgg_fc8_weight_initializer = tf.variables_initializer(vgg_fc8_weithts)#下面用到了之后要用到的初始化方法，\n",
    "\n",
    "optimization_variables_initializer = tf.variables_initializer(adam_optimizer_variables)\n",
    "\n",
    "init_op = tf.global_variables_initializer()\n",
    "sess_config = tf.ConfigProto()#这3行就是定义的seession\n",
    "sess_config.gpu_options.allow_growth = True\n",
    "sess = tf.Session(config= sess_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "with sess:#对 session 进行一系列的处理，\n",
    "    sess.run(vgg_fc8_weight_initializer)#运行初始化方法\n",
    "    sess.run(optimization_variables_initializer)\n",
    "    read_vgg_weights_except_fc8_func(sess)#运行restore的这个op，\n",
    "    \n",
    "    #下面就是一些常规的训练过程以及\n",
    "    train_image, train_annotation = sess.run([image_tensor, annotation_tensor], feed_dict=feed_dict_to_use)\n",
    "                                                     #训练过程和训练结果，image_tensor在计算图中定义的decode中的image\n",
    "                                                    #以及相应的annotation_tensor 这个ground truth\n",
    "            \n",
    "            #现在我们把decode之后的train_image，train_annotation的numpy的矩阵的方式，存在的数据取出来，\n",
    "\n",
    "    f, (ax1, ax2) = plt.subplots(1, 2)#用于显示，首先将图片和它用到的图片和ground truth显示出来，用matplot显示出来\n",
    "    ax1.imshow(train_image)\n",
    "    ax1.set_title('Input Image')\n",
    "    probability_graph = ax2.imshow(np.dstack((train_annotation,)*3)*100)\n",
    "    ax2.set_title('Input Ground-Truth Annotation')\n",
    "    plt.show()\n",
    "\n",
    "    downsample_logits_value, train_annotation = sess.run([downsampled_logits_shape, annotation_tensor], feed_dict=feed_dict_to_use)\n",
    "\n",
    "    print(downsampled_logits_shape.shape)\n",
    "    \n",
    "    #下面这个一是只进行了10次，再就是说，得到的probabilities太过粗糙，他是进过1/32的降采样之后，再通过一个升采样之后，\n",
    "    #再还原回来，所以结果比较粗糙，\n",
    "    for i in range(10):#只迭代了10个step,每次只用这一张图片训练，每个step都会把中间的结果进行输出\n",
    "        loss, summary_string = sess.run([cross_entropy_sum, merged_summary_op], feed_dict=feed_dict_to_use)\n",
    "        sess.run(train_step, feed_dict=feed_dict_to_use)#上方对交叉熵进行计算，这里输出数值，然后这行run一下，把我们的梯度更新上去，\n",
    "        pred_np, probabilities_np = sess.run([pred, probabilities], feed_dict=feed_dict_to_use)#获取我们当前预测的，这样的一个预测，也就是之前只\n",
    "                        #用argmax计算出的每一个像素，他的最大可能的分类，是前景还是背景，因为只有2个分类，以及某一个\n",
    "                      #像素他是前景的具体的概率，\n",
    "\n",
    "        summary_string_writer.add_summary(summary_string, i)\n",
    "\n",
    "        cmap = plt.get_cmap('bwr')                           \n",
    "        \n",
    "        #下面这个可视化都是我们对他进行的图像化输出的方式\n",
    "        f, (ax1, ax2, ax3) = plt.subplots(1, 3)\n",
    "        ax1.imshow(np.uint8(pred_np.squeeze() != 1), vmax=1.5, vmin=-0.4, cmap=cmap)\n",
    "        ax1.set_title('Argmax. Iteration #' + str(i))\n",
    "        probability_graph = ax2.imshow(probabilities_np.squeeze()[:,:,0])\n",
    "        ax2.set_title('Probability of the class. Iteration #' + str(i))\n",
    "        mask = np.multiply(np.uint32(pred_np.squeeze()), 128)\n",
    "        mask = np.stack([mask,]*3, axis=-1)\n",
    "        masked_image = np.uint8(np.clip(train_image+mask, 0, 255))\n",
    "        probability_graph = ax3.imshow(masked_image)\n",
    "        plt.colorbar(probability_graph)\n",
    "        plt.show()\n",
    "        print(\"Current Loss \" + str(loss))\n",
    "\n",
    "    feed_dict_to_use[is_training_placeholder] = False#最后我们完成了这10个step的迭代之后呢，\n",
    "\n",
    "    final_predictions, final_probabilities, final_loss = sess.run([pred,#我们再次输出最后的这些信息\n",
    "                                                                   probabilities,cross_entropy_sum], \n",
    "                                                                  feed_dict=feed_dict_to_use)\n",
    "    f, (ax1,ax2,ax3) = plt.subplots(1, 3)\n",
    "\n",
    "    ax1.imshow(np.uint8(final_predictions.squeeze()!=1), vmax=1.5,vmin=-0.4, cmap=cmap)\n",
    "    ax1.set_title(\"Final Argmax\")\n",
    "\n",
    "    probability_graph = ax2.imshow(final_probabilities.squeeze()[:,:,0])\n",
    "    ax2.set_title('Final Probability of the Class')\n",
    "    plt.colorbar(probability_graph)\n",
    "\n",
    "    mask = np.multiply(np.uint32(final_predictions.squeeze()), 128)\n",
    "    mask = np.stack([np.zeros(mask.shape),mask, np.zeros(mask.shape)], axis=-1)\n",
    "\n",
    "    masked_image =np.uint8(np.clip(train_image+mask, 0, 255))\n",
    "    probability_graph = ax3.imshow(masked_image)\n",
    "\n",
    "    plt.show()\n",
    "    print(\"Final Loss: \" + str(final_loss))\n",
    "\n",
    "summary_string_writer.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "CRF handling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#CRF handling\n",
    "\n",
    "import pydensecrf.densecrf as dcrf\n",
    "\n",
    "from pydensecrf.utils import compute_unary, create_pairwise_bilateral, create_pairwise_gaussian, softmax_to_unary\n",
    "\n",
    "img = train_image\n",
    "\n",
    "processed_probabilities = final_probabilities.squeeze()#把最后输出的这样一个可能性置信度的图进行squeese\n",
    "\n",
    "softmax = processed_probabilities.transpose((2,0,1))#transpose是调整他的形态，相当于把softmax，最后输出的onehot的那个depth\n",
    "                                          #调整到前面去，满足我们这个库所需要的数据格式，\n",
    "\n",
    "unary = softmax_to_unary(softmax)#将其向量化\n",
    "\n",
    "print(unary.shape)\n",
    "\n",
    "unary = np.ascontiguousarray(unary)#改成c连续的这样一个数组，其实是对他在内存中的这样一个形态进行调整，\n",
    "\n",
    "d = dcrf.DenseCRF(image.shape[0] * image.shape[1], 2)#计算出这样密集场，图片的image.shape[0]长和image.shape[1]宽，和相应的通道数2\n",
    "\n",
    "d.setUnaryEnergy(unary)#得到我们的一元势能函数，也就是我们最后得到的俄softmax结果作为势能函数的输入\n",
    "\n",
    "\n",
    "#下面就是采用近似的方法，计算他的二元势能函数，如果采用这样的一个全连接方式的话，计算每两个像素点之间的关系，\n",
    "feats = create_pairwise_gaussian(sdims=(10,10), shape=image.shape[:2])#采用高斯近似的方式计算他的二元势能函数，\n",
    "\n",
    "\n",
    "d.addPairwiseEnergy(feats, compat=3, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC)\n",
    "\n",
    "feats = create_pairwise_bilateral(sdims=(50,50), schan=(20,20,20), img = image, chdim=2)\n",
    "\n",
    "d.addPairwiseEnergy(feats. compat=10, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZA_SYMMETRIC)\n",
    "\n",
    "Q = d.inference(5)\n",
    "\n",
    "res = np.argmax(Q, axis=0).reshape((image.shape[0], image.shape[1]))#得到最后的结果，\n",
    "#下面就是进行matplot输出，\n",
    "cmap = plt.get_cmap('bwr')\n",
    "\n",
    "f, (ax1, ax2, ax3) = plt.subplots(1,3)\n",
    "\n",
    "ax1.imshow(res, vmax=1.5, vmin=-0.4, cmap=cmap)\n",
    "ax1.set_title('Segmentation with CRF post-processing')\n",
    "\n",
    "probobality_graph = ax2.imshow(np.dstack((train_annotation,)*3)*100)\n",
    "ax2.set_title('Ground Truth Annotation')\n",
    "\n",
    "mask = np.multiply(np.uint32(res.queeze()),128)\n",
    "mask = np.stack([np.zeros(mask.shape),mask,np.zeros(mask.shape)], axis=-1)\n",
    "\n",
    "masked_image = np.uint8(np.clip(np.uint32(train_image) + mask, 0, 255))\n",
    "\n",
    "probability_grapth = ax3.imshow(masked_image)\n",
    "\n",
    "plt.show()\n",
    "\n",
    "intersection = np.logical_and(res, train_annotation.squeeze())\n",
    "\n",
    "union = np.logical_or(res, train_annotation.squeeze())\n",
    "\n",
    "sum_intersection = np.sum(intersection)\n",
    "\n",
    "sum_union = np.sum(union)\n",
    "\n",
    "print(\"IoU:%.2f%%\" % ((sum_intersection/sum_union)*100))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
